# -*- coding: utf-8 -*-

### Generate accuracy-coding from raw Chinese prompt responses
### 
### Usage:
### 1. Set input_dir to folder containing raw CSV files with Chinese prompts
### 2. Set output_dir to desired output folder (e.g., './data/accuracy/')
### 3. Ensure chn_prompt_ref_entity_cn.csv is in the same directory as this script
###    (or update ref_file path accordingly)
### 4. Run script: python accuracy.py
###
### Note:
### This script processes only Chinese-language output files (*_chn_combined.csv).
### English translation files (*_engtrans_combined.csv) are not processed.

import os
import pandas as pd
import re

# --- User-configurable parameters ---
# Input directory containing the raw Chinese-prompt CSV files (enter yours).
input_dir = ""
# Output directory for the accuracy-coded CSV files (enter yours).
output_dir = ""
# Reference file with entity annotations; update the path if the file is not
# in the same directory as this script.
ref_file = "chn_prompt_ref_entity_cn.csv"

# Phrases that denote a non-response.
refusal_keywords = [
    "抱歉",
    "我无法",
    "我不能",
    "我没有",
    "我不知道",
    "对不起",
]

# Markers that end a reasoning segment: </think> (any letter case) and the
# Chinese variants. The optional '>' after /思考 and /想象 consumes the whole
# closing tag so that no stray '>' is left at the start of the answer.
_THINK_END_PATTERN = re.compile(
    r'(?:</think>|/思考>?|/想象>?|思考"|思考>)',
    flags=re.IGNORECASE,
)


# Utility to strip <think> and <思考> segments
def strip_think_segments(text):
    """Return the part of *text* that follows the last reasoning marker.

    Parameters
    ----------
    text : object
        Raw response cell. Missing values (None / NaN / pd.NA) are treated
        as an empty response; other non-string values are converted with
        ``str()`` before matching.

    Returns
    -------
    str
        The text after the last marker matched by ``_THINK_END_PATTERN``,
        with surrounding whitespace removed. When no marker is present the
        whole text is returned, stripped. Missing input yields ``""``.
    """
    if not isinstance(text, str):
        if pd.isna(text):
            return ""
        # e.g. a numeric cell: the regex needs a string to work on
        text = str(text)

    # Keep only what follows the LAST marker; 0 means "no marker found".
    last_end = 0
    for match in _THINK_END_PATTERN.finditer(text):
        last_end = match.end()
    return text[last_end:].strip()

# Main classification function
def classify_response(response_text, reference_text, entities):
    """Assign an accuracy label to a single model response.

    Parameters
    ----------
    response_text : object
        Raw response cell; reasoning segments are removed with
        ``strip_think_segments`` before any rule is applied.
    reference_text : object
        Reference answer for the question. No rule reads it at present; the
        parameter is kept so existing callers keep working.
    entities : object
        Comma-separated entity names that the response must contain. A
        missing value (None / NaN) means "no entities".

    Returns
    -------
    str
        ``"refusal to respond"`` when the stripped response is blank, or is
        shorter than 100 characters and contains a refusal keyword;
        ``"mixed accuracy"`` when at least one entity is listed and every
        listed entity occurs in the response;
        ``"completely inaccurate"`` otherwise (including when no entities
        are listed).
    """
    # strip_think_segments returns an already-stripped string
    response_text = strip_think_segments(response_text)

    # Rule 1: explicit refusal or blank
    if not response_text:
        return "refusal to respond"
    # Only short answers count as refusals: a long answer may contain an
    # apology phrase and still go on to answer the question.
    if len(response_text) < 100 and any(
        keyword in response_text for keyword in refusal_keywords
    ):
        return "refusal to respond"

    # Rule 2: check all entities present
    # A missing cell must not be stringified: str(nan) == "nan" would become
    # a bogus entity that matches any response containing "nan".
    if isinstance(entities, str):
        entity_source = entities
    elif pd.isna(entities):
        entity_source = ""
    else:
        entity_source = str(entities)

    # NOTE(review): only the ASCII comma separates entities; a full-width
    # "，" in the reference file would not split - confirm the file format.
    entity_list = [part.strip() for part in entity_source.split(',') if part.strip()]
    if entity_list and all(entity in response_text for entity in entity_list):
        return "mixed accuracy"

    return "completely inaccurate"

# Process model responses dataframe against reference file
def process_response(response_df, ref_df):
    """Label every conversation column of *response_df* and add summary rates.

    Parameters
    ----------
    response_df : pandas.DataFrame
        Model responses. The first column is used as the question index;
        every column whose name is a string starting with ``"conv"`` is
        treated as one conversation's responses. The frame is not modified.
    ref_df : pandas.DataFrame
        Reference table with the columns ``qns_idx``, ``ref`` and
        ``entities``. When a ``qns_idx`` value is repeated, the last row wins.

    Returns
    -------
    pandas.DataFrame
        One row per response row with the columns ``qns_idx``, one label
        column per conversation column, ``nonresp_per`` and
        ``inaccurate_per``. The two summary columns hold the fraction (0-1)
        of conversation columns labelled ``"refusal to respond"`` and
        ``"completely inaccurate"``. Rows labelled ``"no reference"`` still
        count in the denominator. Both are NaN when there are no
        conversation columns.

    Raises
    ------
    KeyError
        If *ref_df* lacks one of the required columns.
    """
    # Always use the first column as qns_idx (read it without writing a new
    # column back into the caller's dataframe)
    qns_idx = response_df.iloc[:, 0]

    conv_columns = [
        col for col in response_df.columns
        if isinstance(col, str) and col.startswith('conv')
    ]
    output_data = {'qns_idx': qns_idx}

    # Precompute reference mapping: qns_idx -> (reference_text, entities)
    reference_lookup = dict(
        zip(ref_df['qns_idx'], zip(ref_df['ref'], ref_df['entities']))
    )

    # Check response of each conversation (conv column)
    total_cols = len(conv_columns)
    for i, conv_col in enumerate(conv_columns, start=1):
        labels = []
        # Walk the index and this one column in lockstep. Column access by
        # label works for any column name, whereas attribute access on
        # itertuples() rows fails for names that are not valid identifiers.
        for idx, response_text in zip(qns_idx, response_df[conv_col]):
            if idx not in reference_lookup:
                labels.append("no reference")
                continue
            reference_text, entities = reference_lookup[idx]
            labels.append(classify_response(response_text, reference_text, entities))
        output_data[conv_col] = labels

        if i % 25 == 0 or i == total_cols:
            print(f"Processed {i}/{total_cols} conversation columns so far...")

    # Build dataframe
    result_df = pd.DataFrame(output_data)

    # Compute summary rates (fraction of conversation columns per row)
    if conv_columns:
        label_frame = result_df[conv_columns]
        refusal_share = (label_frame == "refusal to respond").sum(axis=1) / total_cols
        inaccurate_share = (label_frame == "completely inaccurate").sum(axis=1) / total_cols
    else:
        # Nothing to average over; avoid dividing by zero
        refusal_share = float('nan')
        inaccurate_share = float('nan')
    result_df['nonresp_per'] = refusal_share
    result_df['inaccurate_per'] = inaccurate_share

    return result_df

# Process and save outputs
def main():
    """Code every ``*_chn_combined.csv`` file in ``input_dir`` and save the results.

    Reads the reference file once, creates ``output_dir`` if needed, then
    runs ``process_response`` on each matching file (in sorted filename
    order) and writes ``<name>_accuracy.csv`` into ``output_dir``.

    Raises
    ------
    SystemExit
        If ``input_dir`` or ``output_dir`` has not been filled in.
    """
    # Empty placeholders would otherwise surface as an opaque
    # FileNotFoundError on '' from os.makedirs / os.listdir.
    if not input_dir or not output_dir:
        raise SystemExit(
            "Set input_dir and output_dir at the top of this script before running."
        )

    # Load reference file once
    ref_content_df = pd.read_csv(ref_file)

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Process each CSV file in the input directory; sorted() makes the
    # processing order reproducible (os.listdir order is arbitrary)
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith("_chn_combined.csv"):
            continue
        filepath = os.path.join(input_dir, filename)
        response_df = pd.read_csv(filepath)
        result_df = process_response(response_df, ref_content_df)
        # Swap only the trailing extension; str.replace would also rewrite
        # any ".csv" occurring earlier in the name
        out_filename = filename[:-len(".csv")] + "_accuracy.csv"
        out_filepath = os.path.join(output_dir, out_filename)
        result_df.to_csv(out_filepath, index=False)
        print(f"Processed {filename} ➔ {out_filename}")

    print("Batch processing complete.")


if __name__ == "__main__":
    main()